{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models.RNN_FCN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN_FCN\n",
    "\n",
    "> This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@gmail.com based on:\n",
    "\n",
    "* Karim, F., Majumdar, S., Darabi, H., & Chen, S. (2017). LSTM fully convolutional networks for time series classification. IEEE Access, 6, 1662-1669.\n",
    "* Official LSTM-FCN TensorFlow implementation: https://github.com/titu1994/LSTM-FCN\n",
    "\n",
    "* Elsayed, N., Maida, A. S., & Bayoumi, M. (2018). Deep Gated Recurrent and Convolutional Network Hybrid Model for Univariate Time Series Classification. arXiv preprint arXiv:1812.07683.\n",
    "* Official GRU-FCN TensorFlow implementation: https://github.com/NellyElsayed/GRU-FCN-model-for-univariate-time-series-classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class _RNN_FCN_Base(Module):\n",
    "    def __init__(self, c_in, c_out, seq_len=None, hidden_size=100, rnn_layers=1, bias=True, cell_dropout=0, rnn_dropout=0.8, bidirectional=False, shuffle=True, \n",
    "                 fc_dropout=0., conv_layers=[128, 256, 128], kss=[7, 5, 3], se=0):\n",
    "        \n",
    "        if shuffle: assert seq_len is not None, 'need seq_len if shuffle=True'\n",
    "            \n",
    "        # RNN\n",
    "        self.rnn = self._cell(seq_len if shuffle else c_in, hidden_size, num_layers=rnn_layers, bias=bias, batch_first=True, \n",
    "                              dropout=cell_dropout, bidirectional=bidirectional)\n",
    "        self.rnn_dropout = nn.Dropout(rnn_dropout) if rnn_dropout else noop\n",
    "        self.shuffle = Permute(0,2,1) if not shuffle else noop # You would normally permute x. Authors did the opposite.\n",
    "        \n",
    "        # FCN\n",
    "        assert len(conv_layers) == len(kss)\n",
    "        self.convblock1 = ConvBlock(c_in, conv_layers[0], kss[0])\n",
    "        self.se1 = SqueezeExciteBlock(conv_layers[0], se) if se != 0 else noop\n",
    "        self.convblock2 = ConvBlock(conv_layers[0], conv_layers[1], kss[1])\n",
    "        self.se2 = SqueezeExciteBlock(conv_layers[1], se) if se != 0 else noop\n",
    "        self.convblock3 = ConvBlock(conv_layers[1], conv_layers[2], kss[2])\n",
    "        self.gap = GAP1d(1)\n",
    "        \n",
    "        # Common\n",
    "        self.concat = Concat()\n",
    "        self.fc_dropout = nn.Dropout(fc_dropout) if fc_dropout else noop\n",
    "        self.fc = nn.Linear(hidden_size * (1 + bidirectional) + conv_layers[-1], c_out)\n",
    "        \n",
    "\n",
    "    def forward(self, x):  \n",
    "        # RNN\n",
    "        rnn_input = self.shuffle(x) # permute --> (batch_size, seq_len, n_vars) when batch_first=True\n",
    "        output, _ = self.rnn(rnn_input)\n",
    "        last_out = output[:, -1] # output of last sequence step (many-to-one)\n",
    "        last_out = self.rnn_dropout(last_out)\n",
    "        \n",
    "        # FCN\n",
    "        x = self.convblock1(x)\n",
    "        x = self.se1(x)\n",
    "        x = self.convblock2(x)\n",
    "        x = self.se2(x)\n",
    "        x = self.convblock3(x)\n",
    "        x = self.gap(x)\n",
    "\n",
    "        # Concat\n",
    "        x = self.concat([last_out, x])\n",
    "        x = self.fc_dropout(x)\n",
    "        x = self.fc(x)\n",
    "        return x\n",
    "            \n",
    "\n",
    "class RNN_FCN(_RNN_FCN_Base):\n",
    "    _cell = nn.RNN\n",
    "    \n",
    "class LSTM_FCN(_RNN_FCN_Base):\n",
    "    _cell = nn.LSTM\n",
    "    \n",
    "class GRU_FCN(_RNN_FCN_Base):\n",
    "    _cell = nn.GRU\n",
    "    \n",
    "class MRNN_FCN(_RNN_FCN_Base):\n",
    "    _cell = nn.RNN\n",
    "    def __init__(self, *args, se=16, **kwargs):\n",
    "        super().__init__(*args, se=16, **kwargs)\n",
    "    \n",
    "class MLSTM_FCN(_RNN_FCN_Base):\n",
    "    _cell = nn.LSTM\n",
    "    def __init__(self, *args, se=16, **kwargs):\n",
    "        super().__init__(*args, se=16, **kwargs)\n",
    "    \n",
    "class MGRU_FCN(_RNN_FCN_Base):\n",
    "    _cell = nn.GRU\n",
    "    def __init__(self, *args, se=16, **kwargs):\n",
    "        super().__init__(*args, se=16, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "n_vars = 3\n",
    "seq_len = 12\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, n_vars, seq_len)\n",
    "test_eq(RNN_FCN(n_vars, c_out, seq_len)(xb).shape, [bs, c_out])\n",
    "test_eq(LSTM_FCN(n_vars, c_out, seq_len)(xb).shape, [bs, c_out])\n",
    "test_eq(MLSTM_FCN(n_vars, c_out, seq_len)(xb).shape, [bs, c_out])\n",
    "test_eq(GRU_FCN(n_vars, c_out, shuffle=False)(xb).shape, [bs, c_out])\n",
    "test_eq(GRU_FCN(n_vars, c_out, seq_len, shuffle=False)(xb).shape, [bs, c_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LSTM_FCN(\n",
       "  (rnn): LSTM(2, 100, batch_first=True)\n",
       "  (rnn_dropout): Dropout(p=0.8, inplace=False)\n",
       "  (convblock1): ConvBlock(\n",
       "    (0): Conv1d(3, 128, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)\n",
       "    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (se1): SqueezeExciteBlock(\n",
       "    (avg_pool): GAP1d(\n",
       "      (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "      (flatten): Flatten(full=False)\n",
       "    )\n",
       "    (fc): Sequential(\n",
       "      (0): Linear(in_features=128, out_features=16, bias=False)\n",
       "      (1): ReLU()\n",
       "      (2): Linear(in_features=16, out_features=128, bias=False)\n",
       "      (3): Sigmoid()\n",
       "    )\n",
       "  )\n",
       "  (convblock2): ConvBlock(\n",
       "    (0): Conv1d(128, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (se2): SqueezeExciteBlock(\n",
       "    (avg_pool): GAP1d(\n",
       "      (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "      (flatten): Flatten(full=False)\n",
       "    )\n",
       "    (fc): Sequential(\n",
       "      (0): Linear(in_features=256, out_features=32, bias=False)\n",
       "      (1): ReLU()\n",
       "      (2): Linear(in_features=32, out_features=256, bias=False)\n",
       "      (3): Sigmoid()\n",
       "    )\n",
       "  )\n",
       "  (convblock3): ConvBlock(\n",
       "    (0): Conv1d(256, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (gap): GAP1d(\n",
       "    (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "    (flatten): Flatten(full=False)\n",
       "  )\n",
       "  (concat): Concat(1)\n",
       "  (fc): Linear(in_features=228, out_features=12, bias=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "LSTM_FCN(n_vars, seq_len, c_out, se=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "out = create_scripts()\n",
    "beep(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
